# import
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
from sklearn.utils import resample
from numba import njit

import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from util.models import *
from util.misc import *

## mat file
from scipy.io import loadmat


@njit
def train_CPT(X_train, Y_train, param_grid):
    """Pick the CPT parameter vector with the smallest training MSE (grid search).

    Parameters
    ----------
    X_train : ndarray(float, ndim=2)
        Training covariates; columns are [z1, z2, p1].
    Y_train : ndarray(float, ndim=1)
        Training certainty equivalents (ce).
    param_grid : ndarray(float, ndim=2)
        Candidate parameter combinations, one per row.

    Returns
    -------
    best_param : ndarray(float, ndim=1)
        The row of `param_grid` with the lowest mean squared error
        (the first such row when several tie, as np.argmin does).
    """
    errors = train_CPT_main(X_train[:, 0], X_train[:, 1], X_train[:, 2],
                            Y_train, param_grid)
    return param_grid[np.argmin(errors)]


@njit
def train_CPT_main(z1, z2, p1, Y_train, param_grid):
    """Evaluate the training MSE of pred_CPT at every candidate parameter.

    Parameters
    ----------
    z1, z2, p1 : ndarray(float, ndim=1)
        Covariates in the training data.
    Y_train : ndarray(float, ndim=1)
        Observed certainty equivalents.
    param_grid : ndarray(float, ndim=2)
        Candidate parameter combinations, one per row.

    Returns
    -------
    errors : ndarray(float, ndim=1)
        errors[k] is the mean squared error of pred_CPT under param_grid[k].
    """
    n_candidates = param_grid.shape[0]
    errors = np.zeros(n_candidates)
    for k in range(n_candidates):
        residuals = Y_train - pred_CPT(z1, z2, p1, param_grid[k])
        errors[k] = np.mean(residuals ** 2)
    return errors


@njit
def train_DA(X_train, Y_train, param_grid):
    """Pick the DA parameter vector with the smallest training MSE (grid search).

    Parameters
    ----------
    X_train : ndarray(float, ndim=2)
        Training covariates; columns are [z1, z2, p1].
    Y_train : ndarray(float, ndim=1)
        Training certainty equivalents (ce).
    param_grid : ndarray(float, ndim=2)
        Candidate parameter combinations, one per row.

    Returns
    -------
    best_param : ndarray(float, ndim=1)
        The row of `param_grid` with the lowest mean squared error
        (the first such row when several tie, as np.argmin does).
    """
    errors = train_DA_main(X_train[:, 0], X_train[:, 1], X_train[:, 2],
                           Y_train, param_grid)
    return param_grid[np.argmin(errors)]


@njit
def train_DA_main(z1, z2, p1, Y_train, param_grid):
    """Evaluate the training MSE of pred_DA at every candidate parameter.

    Parameters
    ----------
    z1, z2, p1 : ndarray(float, ndim=1)
        Covariates in the training data.
    Y_train : ndarray(float, ndim=1)
        Observed certainty equivalents.
    param_grid : ndarray(float, ndim=2)
        Candidate parameter combinations, one per row.

    Returns
    -------
    errors : ndarray(float, ndim=1)
        errors[k] is the mean squared error of pred_DA under param_grid[k].
    """
    n_candidates = param_grid.shape[0]
    errors = np.zeros(n_candidates)
    for k in range(n_candidates):
        residuals = Y_train - pred_DA(z1, z2, p1, param_grid[k])
        errors[k] = np.mean(residuals ** 2)
    return errors


@njit
def train_best_pred(cov_idx_train, Y_train, cov_list):
    """Return the best (ideal) prediction model as a lookup table.

    For each lottery index cov_list[i], best_pred[i] is the conditional mean
    of the ce's in the training data whose cov_idx equals cov_list[i].

    Parameters
    ----------
    cov_idx_train : ndarray(ndim=1)
        Lottery index of every training row.
    Y_train : ndarray(float, ndim=1)
        Certainty equivalents of the training rows.
    cov_list : ndarray(int, ndim=1)
        All lottery indices for which a prediction is needed.

    Returns
    -------
    best_pred : ndarray(float64, ndim=1)
        Conditional means aligned with `cov_list`; NaN for a lottery that has
        no training rows.
    """
    n_cov = cov_list.shape[0]
    # Must be a float array: np.zeros_like(cov_list) would inherit the integer
    # dtype of cov_list and silently truncate the conditional means.
    best_pred = np.zeros(n_cov)
    for i_cov in range(n_cov):
        Y_temp = Y_train[cov_idx_train == cov_list[i_cov]]
        if Y_temp.size == 0:
            # no training observation for this lottery: make it explicit
            best_pred[i_cov] = np.nan
        else:
            best_pred[i_cov] = np.mean(Y_temp)
    return best_pred


@njit
def compute_best_SE(best_pred, cov_idx_test, Y_test, cov_list):
    """Squared errors of the best (lookup-table) predictor on the test data.

    Each test row is predicted by best_pred[i] where cov_list[i] equals the
    row's lottery index; rows whose index is not in cov_list keep prediction 0.

    Parameters
    ----------
    best_pred : ndarray(float, ndim=1)
        Predictions aligned with `cov_list` (see `train_best_pred`).
    cov_idx_test : ndarray(ndim=1)
        Lottery index of every test row.
    Y_test : ndarray(ndim=1)
        Observed certainty equivalents of the test rows.
    cov_list : ndarray(int, ndim=1)
        All lottery indices.

    Returns
    -------
    best_MSE : float
        Mean of `best_SE`.
    best_SE : ndarray(float, ndim=1)
        Per-row squared errors.
    """
    # float64 buffer regardless of Y_test's dtype, so the (float) predictions
    # are never truncated when Y_test happens to be integer-typed
    Y_pred = np.zeros(Y_test.shape[0])
    for i_cov in range(cov_list.shape[0]):
        Y_pred[cov_idx_test == cov_list[i_cov]] = best_pred[i_cov]
    best_SE = (Y_test - Y_pred) ** 2
    best_MSE = np.mean(best_SE)
    return best_MSE, best_SE


@njit
def compute_model_SE(pred, X_test, Y_test, model_param):
    """Squared errors of a parametric model on the test data.

    Parameters
    ----------
    pred : function
        Prediction function called as pred(z1, z2, p1, model_param).
    X_test : ndarray(float, ndim=2)
        Test covariates; columns are [z1, z2, p1].
    Y_test : ndarray(float, ndim=1)
        Observed certainty equivalents.
    model_param : ndarray(float, ndim=1)
        Parameter vector passed to `pred`.

    Returns
    -------
    model_MSE : float
        Mean of `model_SE`.
    model_SE : ndarray(float, ndim=1)
        Per-row squared errors.
    """
    residuals = Y_test - pred(X_test[:, 0], X_test[:, 1], X_test[:, 2], model_param)
    model_SE = residuals ** 2
    return np.mean(model_SE), model_SE


@njit
def compute_var(idx_test, model_SE, base_SE, best_SE):
    """Sample variances/covariance of the excess squared errors (ddof=1).

    With d_model = model_SE - best_SE and d_base = base_SE - best_SE, returns
    var(d_model), var(d_base) and cov(d_model, d_base), each normalised by
    len(idx_test) - 1. `idx_test` is used only for its length.

    np.var / np.cov with the `ddof` argument are not supported by numba, so
    the sums are written out explicitly.
    """
    dof = idx_test.shape[0] - 1
    model_dev = model_SE - best_SE
    model_dev = model_dev - np.mean(model_dev)
    base_dev = base_SE - best_SE
    base_dev = base_dev - np.mean(base_dev)

    model_var = np.sum(model_dev ** 2) / dof
    base_var = np.sum(base_dev ** 2) / dof
    covar = np.sum(model_dev * base_dev) / dof
    return model_var, base_var, covar


@njit
def one_fold(X, Y, idx, cov_idx, cov_list,
             train_idx, test_idx,
             pred, train_model, base_param, param_grid):
    """Main part of cross validation for a single train/test partition.

    Fits the model on the training rows, builds the best (lookup-table)
    predictor, and evaluates model, baseline and best predictor on the test rows.

    Parameters
    ----------
    X, Y, idx, cov_idx, cov_list : ndarrays
        See `cross_validation`.
    train_idx : ndarray(bool, ndim=1)
        Boolean mask over the rows; True iff the row is in the training data
        (cross_validation builds it with np.isin).
    test_idx : ndarray(bool, ndim=1)
        Boolean mask over the rows; True iff the row is in the test data.
    pred, train_model, base_param, param_grid : See `cross_validation`.

    Returns
    -------
    model_MSE, base_MSE, best_MSE : float
        Test MSEs of the fitted model, the baseline model and the best predictor.
    model_var, base_var, covar : float
        Output of `compute_var` on the test rows.
    best_param : ndarray(float, ndim=1)
        Parameter chosen by `train_model` on the training rows.
    """
    # training fold
    Y_train = Y[train_idx]
    best_param = train_model(X[train_idx], Y_train, param_grid)
    best_pred = train_best_pred(cov_idx[train_idx], Y_train, cov_list)

    # test fold
    X_test = X[test_idx]
    Y_test = Y[test_idx]
    model_MSE, model_SE = compute_model_SE(pred, X_test, Y_test, best_param)
    base_MSE, base_SE = compute_model_SE(pred, X_test, Y_test, base_param)
    best_MSE, best_SE = compute_best_SE(best_pred, cov_idx[test_idx], Y_test, cov_list)
    model_var, base_var, covar = compute_var(idx[test_idx], model_SE, base_SE, best_SE)

    return model_MSE, base_MSE, best_MSE, model_var, base_var, covar, best_param


def preprocess(df):
    """Preprocess a pandas dataframe to prepare the inputs for cross_validation

    Parameters
    ----------
    df : pandas.DataFrame
        the dataframe should contain the following columns:
        z1 : the amount of high prize
        z2 : the amount of low prize
        p1 : the probability of getting higher prize
        ce : the reported certainty equivalent

    Returns
    -------
    X : ndarray(ndim=2)
        Array containing z1, z2, and p1 (one row per observation)
    Y : ndarray(ndim=1)
        Array containing reported certainty equivalents
    idx : ndarray(int, ndim=1)
        Row positions 0, 1, ..., len(df) - 1
    cov_idx : ndarray(int, ndim=1)
        For each row, the index of its (z1, z2, p1) combination among the
        unique rows of X (in the lexicographic order of np.unique). Rows are
        matched with np.isclose tolerances; if a row matches several unique
        rows, the last one wins.
    cov_list : ndarray(int, ndim=1)
        Sorted set of the values appearing in cov_idx
        (i.e., all possible indices for lotteries)

    Example
    -------
    If X[10] = (100, 0, 0.5), Y[10] = 30, idx[10] = 10, cov_idx[10] = 8, this means
    the certainty equivalent reported in row 10 for lottery (100,0;0.5) is 30,
    and we call the lottery lottery 8.
    """
    X = df[['z1', 'z2', 'p1']].values
    Y = df['ce'].values
    idx = np.arange(df.shape[0])

    covariates = np.unique(X, axis=0)
    cov_idx = np.full(X.shape[0], -1, dtype=int)
    for l, cov in enumerate(covariates):
        # one vectorized comparison per lottery; row-wise equivalent to
        # np.allclose(X[i], cov) (same default rtol/atol)
        cov_idx[np.isclose(X, cov).all(axis=1)] = l

    cov_list = np.unique(cov_idx)
    return X, Y, idx, cov_idx, cov_list


def _presplit_fold_masks(idx, n_splits):
    """Yield (train_mask, test_mask) for each pre-computed fold stored on disk.

    For fold k, the integer indices listed in
    ``fold_no_cluster/train_idx_k.csv`` and ``fold_no_cluster/test_idx_k.csv``
    (relative to this file's directory) are matched against ``idx`` with
    ``np.isin`` to build boolean row masks.
    """
    base = os.path.dirname(os.path.abspath(__file__))
    for fold in range(n_splits):
        file_name = os.path.normpath(os.path.join(base, './fold_no_cluster/train_idx_{}.csv'.format(fold)))
        train_ids = np.loadtxt(file_name, dtype='int32')
        file_name = os.path.normpath(os.path.join(base, './fold_no_cluster/test_idx_{}.csv'.format(fold)))
        test_ids = np.loadtxt(file_name, dtype='int32')
        yield np.isin(idx, train_ids), np.isin(idx, test_ids)


def _clustered_fold_masks(idx, n_splits):
    """Yield (train_mask, test_mask) for a shuffled K-fold split over clusters.

    The distinct values of ``idx`` are the clusters (e.g. subject ids); all
    rows sharing a value end up on the same side of every split.
    """
    idx_clustered = np.unique(idx)  # sorted distinct cluster ids
    kf = KFold(n_splits, shuffle=True)
    for train_pos, test_pos in kf.split(idx_clustered):
        # NB: kf.split returns positions in idx_clustered, not the cluster ids
        train_ids = idx_clustered[train_pos]
        test_ids = idx_clustered[test_pos]
        yield np.isin(idx, train_ids), np.isin(idx, test_ids)


def cross_validation(X, Y, idx, cov_idx, cov_list, param_grid,
                     pred=pred_CPT, train_model=train_CPT, base_param=np.array([1,1,1]),
                     n_splits=10, print_option=True, bootstrap_option=False):
    """Conduct K-fold cross validation and compute completeness
    Inputs are generated via `preprocess` and `cartesian_product`.

    Parameters
    ----------
    X : ndarray(float, ndim=2)
        Array containing z1, z2, and p1
    Y : ndarray(float, ndim=1)
        Array containing reported certainty equivalents
    idx : ndarray(ndim=1)
        Row identifiers. With bootstrap_option=False they are matched against
        the indices stored in the fold files; with bootstrap_option=True their
        distinct values are the clusters (e.g. subject ids) that are split.
    cov_idx : ndarray(int, ndim=1)
        Array containing indices for possible covariate combinations of (z1, z2, p1)'s
    cov_list : ndarray(int, ndim=1)
        Array containing the set of all indices for possible covariates
        (i.e., all possible indices for lotteries)
    param_grid : ndarray(float, ndim=2)
        Candidate parameter combinations, one per row
    pred : function, optional
        prediction model, by default pred_CPT
    train_model : function, optional
        method of training the model, by default train_CPT
    base_param : ndarray, optional
        parameter for the baseline model, by default [1,1,1]
    n_splits : int, optional
        the number of folds for cross validation, by default 10
    print_option : bool, optional
        If true, the result is printed out, by default True
    bootstrap_option : bool, optional
        If False (default), the folds are read from
        ``fold_no_cluster/train_idx_{k}.csv`` / ``test_idx_{k}.csv`` next to
        this file. If True, a shuffled ``KFold`` split over the distinct values
        of ``idx`` is used instead, so rows sharing an ``idx`` value stay in
        the same fold (this is what `bootstrap_se` / `bootstrap_one_iter` use).

    Returns
    -------
    completeness : float
        1 - (model_MSE - best_MSE) / (base_MSE - best_MSE), with each MSE
        averaged over the folds
    stderr : float
        Analytical SE: sqrt of the delta-method variance of completeness,
        built from the fold-averaged variances/covariance, divided by len(Y)
    model_best_param : ndarray(float, ndim=1)
        The optimal parameter (computed as the average of optimal parameters for K folds)
    """
    if bootstrap_option:
        fold_masks = _clustered_fold_masks(idx, n_splits)
    else:
        fold_masks = _presplit_fold_masks(idx, n_splits)

    model_MSEs, base_MSEs, best_MSEs = [], [], []
    model_vars, base_vars, covars = [], [], []
    model_best_params = []
    for train_idx, test_idx in fold_masks:
        (model_MSE, base_MSE, best_MSE,
         model_var, base_var, covar, best_param) = one_fold(
            X, Y, idx, cov_idx, cov_list,
            train_idx, test_idx,
            pred, train_model, base_param, param_grid)

        model_MSEs.append(model_MSE)
        base_MSEs.append(base_MSE)
        best_MSEs.append(best_MSE)
        model_vars.append(model_var)
        base_vars.append(base_var)
        covars.append(covar)
        model_best_params.append(best_param)

    # wrap up results (simple averages over folds)
    model_MSE = np.mean(model_MSEs)
    base_MSE = np.mean(base_MSEs)
    best_MSE = np.mean(best_MSEs)
    model_var = np.mean(model_vars)
    base_var = np.mean(base_vars)
    covar = np.mean(covars)
    model_best_param = np.mean(np.array(model_best_params), axis=0)

    completeness = 1 - (model_MSE - best_MSE)/(base_MSE - best_MSE)
    var_kappa = (model_var - 2*completeness*covar + (completeness**2)*base_var)/((base_MSE - best_MSE)**2)
    sample_size = Y.shape[0]
    stderr = np.sqrt(var_kappa/sample_size)

    if print_option:
        print('sample_size: {}'.format(sample_size))
        print('completeness: {}'.format(completeness))
        print('stderr (analytical): {}'.format(stderr))
        print('model_best_param: {}'.format(model_best_param))
        print('model_MSEs: {}'.format(model_MSEs))
        print('best_MSEs: {}'.format(best_MSEs))
        print('base_MSEs: {}'.format(base_MSEs))

    return completeness, stderr, model_best_param


def bootstrap_se(X, Y, subject_idx, cov_idx, cov_list, param_grid,
                 base_param=np.array([1,1,1]), pred=pred_CPT, train_model=train_CPT,
                 bs_sample_size=50, seed=20220306):
    """Bootstrap the standard error of completeness by resampling subjects.

    Each replication draws subjects (the distinct values of ``subject_idx``)
    with replacement via ``sklearn.utils.resample``, stacks all rows of every
    drawn subject (a subject drawn twice contributes its rows twice), and runs
    `cross_validation` with ``bootstrap_option=True`` on the resampled data.

    Parameters
    ----------
    X, Y, subject_idx, cov_idx, cov_list, param_grid : ndarrays
        See `cross_validation`; ``subject_idx`` is passed as its ``idx`` and
        defines the clusters that are resampled.
    base_param, pred, train_model : optional
        Forwarded to `cross_validation`.
    bs_sample_size : int, optional
        Number of bootstrap replications, by default 50.
    seed : int, optional
        Passed to ``np.random.seed`` before resampling, by default 20220306.

    Returns
    -------
    stderr : float
        Standard deviation (ddof=0) of the bootstrapped completeness values.
    """
    idx_clustered = np.unique(subject_idx)  # sorted distinct subject ids
    completeness_list = np.zeros(bs_sample_size)

    np.random.seed(seed)
    for i_bs in range(bs_sample_size):
        id_re = resample(idx_clustered)  # may be rewritten only via numpy -- np.random.choice

        # Row positions of every drawn subject, in draw order. Concatenating
        # avoids a fixed-size buffer, which the resampled data can overflow
        # when large subjects are drawn repeatedly.
        rows = np.concatenate([np.flatnonzero(subject_idx == s) for s in id_re])
        X_re = X[rows].astype(np.float64)
        Y_re = Y[rows].astype(np.float64)
        subject_idx_re = subject_idx[rows]
        cov_idx_re = cov_idx[rows]

        completeness, _, __ =\
            cross_validation(X_re, Y_re, subject_idx_re, cov_idx_re, cov_list, param_grid,
                             base_param=base_param, pred=pred, train_model=train_model,
                             n_splits=10, print_option=False, bootstrap_option=True)
        completeness_list[i_bs] = completeness

        # progress report keyed on the loop counter (not on the returned stderr)
        if (i_bs + 1) % 100 == 0:
            print('{} epochs done'.format(i_bs + 1))

    # NB: np.var divides by the number of replications (ddof=0)
    stderr = np.sqrt(np.var(completeness_list))
    print('stderr (bootstrapped): {}'.format(stderr))
    return stderr

def bootstrap_one_iter(X, Y, subject_idx, cov_idx, cov_list, param_grid,
                       base_param=np.array([1,1,1]), pred=pred_CPT, train_model=train_CPT):
    """Run a single subject-level bootstrap replication of completeness.

    Draws subjects (the distinct values of ``subject_idx``) with replacement
    via ``sklearn.utils.resample`` using the current global NumPy random
    state, stacks all rows of each drawn subject, and runs `cross_validation`
    with ``bootstrap_option=True`` on the resampled data.

    Returns
    -------
    completeness : float
        Completeness of the resampled dataset.
    """
    idx_clustered = np.unique(subject_idx)  # sorted distinct subject ids
    id_re = resample(idx_clustered)

    # Row positions of every drawn subject, in draw order. Concatenating avoids
    # a fixed-size buffer that the resampled data could overflow.
    rows = np.concatenate([np.flatnonzero(subject_idx == s) for s in id_re])
    X_re = X[rows].astype(np.float64)
    Y_re = Y[rows].astype(np.float64)
    subject_idx_re = subject_idx[rows]
    cov_idx_re = cov_idx[rows]

    completeness, _, __ =\
        cross_validation(X_re, Y_re, subject_idx_re, cov_idx_re, cov_list, param_grid,
                         base_param=base_param, pred=pred, train_model=train_model,
                         n_splits=10, print_option=False, bootstrap_option=True)

    return completeness


if __name__ == '__main__':
    import time

    def _report_elapsed(start_time):
        """Print a blank line and the time elapsed since ``start_time`` as h:m:s."""
        hour, remainder = divmod(time.perf_counter() - start_time, 3600)
        minute, second = divmod(remainder, 60)
        print('')
        print('process time: {}:{}:{}'.format(int(hour), int(minute), int(second)))

    # fix the global NumPy seed and record it
    seed = 20220219
    np.random.seed(seed)
    print('seed={}'.format(seed))

    # load dataset: every non-metadata variable of the .mat file becomes a column
    base = os.path.dirname(os.path.abspath(__file__))
    file_path = os.path.normpath(os.path.join(base, '../data/Bruhin_et_al_2010.mat'))
    matdata = loadmat(file_path)
    df_full = pd.DataFrame({key: value.flatten()
                            for key, value in matdata.items() if '__' not in key})

    ## drop lotteries over losses
    df = df_full[df_full['z2'] >= 0]

    # preprocessing
    X, Y, subject_idx, cov_idx, cov_list = preprocess(df)

    ## CPT: grid over three parameters
    start_time = time.perf_counter()
    step_alpha = 0.01
    step_gamma = 0.01
    step_delta = 0.1
    param_grid = cartesian_product(
        np.arange(step_alpha, 1 + step_alpha, step_alpha),
        np.arange(step_gamma, 1 + step_gamma, step_gamma),
        np.arange(step_delta, 5 + step_delta, step_delta),
    )

    print('CPT')
    completeness, stderr, model_best_params = \
        cross_validation(X, Y, subject_idx, cov_idx, cov_list, param_grid)
    # bootstrapped SE (slow): bootstrap_se(X, Y, subject_idx, cov_idx, cov_list, param_grid, bs_sample_size=5)
    _report_elapsed(start_time)

    ## DA: grid over two parameters
    start_time = time.perf_counter()
    alpha_step = 0.01
    eta_step = 0.01
    alpha_ubd, eta_lbd, eta_ubd = 1, -1, 5
    param_grid = cartesian_product(
        np.arange(alpha_step, alpha_ubd + alpha_step, alpha_step),
        np.arange(eta_lbd, eta_ubd + eta_step, eta_step),
    )

    print('')
    print('DA')
    cross_validation(X, Y, subject_idx, cov_idx, cov_list, param_grid,
                     pred=pred_DA, train_model=train_DA,
                     base_param=np.array([1,0]), n_splits=10, print_option=True)
    _report_elapsed(start_time)
